from transformers import RobertaTokenizer
import argparse
from tqdm import tqdm
import os
from glob import iglob
import jsonlines
import sys

import sys
sys.path.append(".")
from rob_baseline import *

def generate_tokens(x, tokenizer):
    """Build two complementary masked-LM inputs for one QA instance.

    The gold token sequence is ``context + ["."] + question + answer``.  Two
    inputs are derived from it, each replacing one span token-for-token with
    the literal ``"<mask>"`` token (RoBERTa's mask token):

    1. the answer span is masked, and
    2. the context span is masked.

    Args:
        x: mapping with ``"context"``, ``"question"`` and ``"answer"`` text.
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_ids``
            (a ``RobertaTokenizer`` in this script).

    Returns:
        Tuple ``(gold_tokens, answer_masked_ids, context_masked_ids, gold_ids)``.
        Every id list has the same length as ``gold_tokens``.
    """
    ctokens = tokenizer.tokenize(x["context"])
    qtokens = tokenizer.tokenize(x["question"])
    ptokens = tokenizer.tokenize(x["answer"])
    sep = ["."]  # a plain "." token joins the context to the question

    gold = ctokens + sep + qtokens + ptokens
    input_ids1 = ctokens + sep + qtokens + ["<mask>"] * len(ptokens)
    input_ids2 = ["<mask>"] * len(ctokens) + sep + qtokens + ptokens
    # Masking is one-for-one, so each variant must stay position-aligned with
    # the gold sequence (the gold ids are used as per-position labels).
    assert len(input_ids1) == len(input_ids2) == len(gold)

    to_ids = tokenizer.convert_tokens_to_ids
    return gold, to_ids(input_ids1), to_ids(input_ids2), to_ids(gold)

def generate_tokens_anli(x, tokenizer):
    """Build three masked-LM inputs for one QA instance (``anli`` datasets).

    The gold token sequence is ``context + ["."] + question + answer``.  Three
    inputs are derived from it, each replacing one span token-for-token with
    the literal ``"<mask>"`` token (RoBERTa's mask token):

    1. the answer span is masked,
    2. the context span is masked, and
    3. the question span is masked.

    Args:
        x: mapping with ``"context"``, ``"question"`` and ``"answer"`` text.
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_ids``
            (a ``RobertaTokenizer`` in this script).

    Returns:
        Tuple ``(gold_tokens, answer_masked_ids, context_masked_ids,
        question_masked_ids, gold_ids)``.  Every id list has the same length
        as ``gold_tokens``.
    """
    ctokens = tokenizer.tokenize(x["context"])
    qtokens = tokenizer.tokenize(x["question"])
    ptokens = tokenizer.tokenize(x["answer"])
    sep = ["."]  # a plain "." token joins the context to the question

    gold = ctokens + sep + qtokens + ptokens
    input_ids1 = ctokens + sep + qtokens + ["<mask>"] * len(ptokens)
    input_ids2 = ["<mask>"] * len(ctokens) + sep + qtokens + ptokens
    input_ids3 = ctokens + sep + ["<mask>"] * len(qtokens) + ptokens
    # Every variant (including the question-masked one, previously unchecked)
    # must stay position-aligned with gold, whose ids serve as the labels.
    assert len(input_ids1) == len(input_ids2) == len(input_ids3) == len(gold)

    to_ids = tokenizer.convert_tokens_to_ids
    return (gold, to_ids(input_ids1), to_ids(input_ids2),
            to_ids(input_ids3), to_ids(gold))


if __name__ == "__main__":
    # Bind the progress-bar class explicitly: the module-level name ``tqdm``
    # is the class (``from tqdm import tqdm``), so ``tqdm.tqdm`` would fail
    # unless the star import from rob_baseline happened to rebind it.
    from tqdm import tqdm as progress_bar

    print("Converting Input QA to CMLM inputs")
    parser = argparse.ArgumentParser()
    parser.add_argument("--inpfile",
                        default=None,
                        type=str,
                        required=True,
                        help="input jsonfile")
    parser.add_argument("--out",
                        default="./",
                        type=str,
                        help="Output .jsonl file, or a directory in which "
                             "<inpfile stem>_cmlm.jsonl will be saved")
    parser.add_argument("--bert_model",
                        default="roberta-large",
                        type=str,
                        help="Roberta Model")
    parser.add_argument("--dtype",
                        default="ktl",
                        type=str,
                        help="Dataset type; selects the reader in INSTANCE_READERS")
    args = parser.parse_args()

    # Fail fast on an unknown dataset type, before the output file is truncated.
    if args.dtype not in INSTANCE_READERS:
        parser.error(f"unknown --dtype {args.dtype!r}; "
                     f"expected one of {list(INSTANCE_READERS)}")

    tokenizer = RobertaTokenizer.from_pretrained(args.bert_model)

    # ``--out`` may name a directory (the default is "./"); jsonlines cannot
    # open a directory for writing, so derive a file name inside it.
    out_cmlm_json_file = args.out
    if os.path.isdir(out_cmlm_json_file):
        stem = os.path.splitext(os.path.basename(args.inpfile))[0]
        out_cmlm_json_file = os.path.join(out_cmlm_json_file, stem + "_cmlm.jsonl")

    # anli rows get an extra question-masked variant.
    build_tokens = generate_tokens_anli if args.dtype == "anli" else generate_tokens
    instance_reader = INSTANCE_READERS[args.dtype]()

    with jsonlines.open(out_cmlm_json_file, "w") as allfd:
        with jsonlines.open(args.inpfile, "r") as fd:
            for record in progress_bar(fd, "Converting:"):
                context, question, label, choices, _ = \
                    instance_reader.fields_to_instance(fields=record)
                qa = {"context": context, "question": question, "answer": choices[label]}
                # Builders return (gold_tokens, *masked_variant_ids, gold_ids).
                gold_tokens, *variant_ids, gold_ids = build_tokens(qa, tokenizer)
                # One output row per masked variant, all sharing the gold labels.
                for input_ids in variant_ids:
                    allfd.write({"inputids": input_ids, "labels": gold_ids,
                                 "gold_texts": gold_tokens})
            